a3b7a2
@@ -68,6 +68,9 @@
 import org.apache.hadoop.hive.ql.plan.OperatorDesc;
 import org.apache.hadoop.hive.ql.plan.Statistics;
 import org.apache.hadoop.hive.ql.stats.StatsUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory;
+import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
 import org.apache.hadoop.util.ReflectionUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -88,6 +91,8 @@
   private static final Logger LOG = LoggerFactory.getLogger(ConvertJoinMapJoin.class.getName());
   public float hashTableLoadFactor;
   private long maxJoinMemory;
+  private HashMapDataStructureType hashMapDataStructure;
+  private boolean fastHashTableAvailable;
 
   @Override
   /*
@@ -102,6 +107,7 @@
     OptimizeTezProcContext context = (OptimizeTezProcContext) procCtx;
 
     hashTableLoadFactor = context.conf.getFloatVar(ConfVars.HIVEHASHTABLELOADFACTOR);
+    fastHashTableAvailable = context.conf.getBoolVar(ConfVars.HIVE_VECTORIZATION_MAPJOIN_NATIVE_FAST_HASHTABLE_ENABLED);
 
     JoinOperator joinOp = (JoinOperator) nd;
     // adjust noconditional task size threshold for LLAP
@@ -116,6 +122,9 @@
 
     LOG.info("maxJoinMemory: {}", maxJoinMemory);
 
+    hashMapDataStructure = HashMapDataStructureType.of(joinOp.getConf());
+
+
     TezBucketJoinProcCtx tezBucketJoinProcCtx = new TezBucketJoinProcCtx(context.conf);
     boolean hiveConvertJoin = context.conf.getBoolVar(HiveConf.ConfVars.HIVECONVERTJOIN) &
             !context.parseContext.getDisableMapJoin();
@@ -193,6 +202,32 @@
     return null;
   }
 
+  private enum HashMapDataStructureType {
+    COMPOSITE_KEYED, LONG_KEYED;
+
+    /** LONG_KEYED for a single integer-like primitive key in keys[0]; else COMPOSITE_KEYED. */
+    public static HashMapDataStructureType of(JoinDesc conf) {
+      ExprNodeDesc[][] keys = conf.getJoinKeys();
+      if (keys != null && keys.length > 0 && keys[0] != null && keys[0].length == 1) {
+        TypeInfo typeInfo = keys[0][0].getTypeInfo();
+        if (typeInfo instanceof PrimitiveTypeInfo) {
+          PrimitiveCategory pCat = ((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory();
+          switch (pCat) {
+          case BOOLEAN:
+          case BYTE:
+          case SHORT:
+          case INT:
+          case LONG:
+            return HashMapDataStructureType.LONG_KEYED;
+          default:
+            break;
+          }
+        }
+      }
+      return HashMapDataStructureType.COMPOSITE_KEYED;
+    }
+  }
+
   private boolean selectJoinForLlap(OptimizeTezProcContext context, JoinOperator joinOp,
                           TezBucketJoinProcCtx tezBucketJoinProcCtx,
                           LlapClusterStateForCompile llapInfo,
@@ -239,6 +274,11 @@
private boolean selectJoinForLlap(OptimizeTezProcContext context, JoinOperator j
     LOG.info("Cost of Bucket Map Join : numNodes = " + numNodes + " total small table size = "
     + totalSize + " networkCostMJ = " + networkCostMJ);
 
+    if (totalSize <= maxJoinMemory) {
+      // mapjoin is applicable; skip the algorithms below.
+      return false;
+    }
+
     if (networkCostDPHJ < networkCostMJ) {
       LOG.info("Dynamically partitioned Hash Join chosen");
       return convertJoinDynamicPartitionedHashJoin(joinOp, context);
@@ -252,17 +292,32 @@
private boolean selectJoinForLlap(OptimizeTezProcContext context, JoinOperator j
   }
 
   public long computeOnlineDataSize(Statistics statistics) {
-    return computeOnlineDataSizeFast3(statistics);
+    // choose the estimator according to the fastHashTableAvailable flag
+    return fastHashTableAvailable
+        ? computeOnlineDataSizeFast(statistics)
+        : computeOnlineDataSizeOptimized(statistics);
+  }
+
+  /** Dispatches to the estimator matching the current hashMapDataStructure value. */
+  public long computeOnlineDataSizeFast(Statistics statistics) {
+    switch (hashMapDataStructure) {
+    case COMPOSITE_KEYED:
+      return computeOnlineDataSizeFastCompositeKeyed(statistics);
+    case LONG_KEYED:
+      return computeOnlineDataSizeFastLongKeyed(statistics);
+    default:
+      throw new RuntimeException("invalid mode");
+    }
   }
 
-  public long computeOnlineDataSizeFast2(Statistics statistics) {
+  public long computeOnlineDataSizeFastLongKeyed(Statistics statistics) { // LONG_KEYED estimate
     return computeOnlineDataSizeGeneric(statistics,
         -8, // the long key is stored in a slot
         2 * 8 // maintenance structure consists of 2 longs
     );
   }
 
-  public long computeOnlineDataSizeFast3(Statistics statistics) {
+  public long computeOnlineDataSizeFastCompositeKeyed(Statistics statistics) {
     return computeOnlineDataSizeGeneric(statistics,
         5 + 4, // list header ; value length stored as vint
         8 // maintenance structure consists of 1 long
@@ -1024,7 +1079,7 @@
public int getMapJoinConversionPos(JoinOperator joinOp, OptimizeTezProcContext c
     // We store the total memory that this MapJoin is going to use,
     // which is calculated as totalSize/buckets, with totalSize
     // equal to sum of small tables size.
-    joinOp.getConf().setInMemoryDataSize(totalSize/buckets);
+    joinOp.getConf().setInMemoryDataSize(totalSize / buckets);
 
     return bigTablePosition;
   }
